package com.qz.sql.api.controller;

import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.qz.sql.editor.bean.DatabaseFactoryBean.DatabaseProduct;
import com.qz.sql.editor.bean.DatabaseFactoryBean;
import com.qz.sql.editor.bean.DatabaseRegistrationBean;
import com.qz.sql.editor.dto.*;
import com.qz.sql.editor.entity.DbDatasource;
import com.qz.sql.editor.exception.ConfirmException;
import com.qz.sql.editor.json.DocDbResponseJson;
import com.qz.sql.editor.json.ResponseJson;
import com.qz.sql.editor.mapper.base.BaseMapper;
import com.qz.sql.editor.mapper.mysql.MysqlMapper;
import com.qz.sql.editor.service.DbDatasourceService;
import com.qz.sql.editor.utils.CachePrefix;
import com.qz.sql.editor.utils.CacheUtil;
import com.qz.sql.editor.vo.DatabaseExportVo;
import com.qz.sql.editor.vo.TableColumnVo.TableInfoVo;
import com.qz.sql.editor.vo.TableColumnVo;
import com.qz.sql.editor.vo.TableStatusVo;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.mybatis.spring.SqlSessionTemplate;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import com.qz.sql.editor.utils.PoiUtils;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Database documentation controller.
 *
 * <p>Serves the POST endpoints under {@code /qz-sql/doc-db}: listing data sources, databases,
 * tables and columns, rendering table DDL, searching tables/columns, editing table and column
 * comments, and exporting table documentation.
 *
 * @author
 * @since
 */
@CrossOrigin(allowCredentials = "true", allowedHeaders = "*")
@RestController
@RequestMapping("/qz-sql/doc-db")
public class DatabaseDocController {

	/** Message returned when no connection is registered for the requested source id. */
	private static final String NO_CONNECTION_MSG = "未找到对应的数据库连接";

	@Resource
	DatabaseRegistrationBean databaseRegistrationBean;
	@Resource
	DbDatasourceService dbDatasourceService;

	/**
	 * Lists the data sources whose {@code yn} column equals 1 (presumably the "enabled" flag —
	 * confirm against the schema). Only id, name, group name and driver class name are selected.
	 *
	 * @return response wrapping the list of {@link DbDatasource}
	 */
	@PostMapping(value = "/getDataSourceList")
	public ResponseJson getDataSourceList() {
		QueryWrapper<DbDatasource> wrapper = new QueryWrapper<>();
		wrapper.eq("yn", 1);
		wrapper.select("id", "name", "group_name", "driver_class_name");
		List<DbDatasource> datasourceList = dbDatasourceService.list(wrapper);
		return DocDbResponseJson.ok(datasourceList);
	}

	/**
	 * Collects everything the SQL editor needs for auto-completion.
	 *
	 * <p>Note: this endpoint returns the structure of every database/table visible to the
	 * connection; disable it manually if that is a concern. Result caching is currently disabled.
	 *
	 * @param sourceId data source id
	 * @return response wrapping a map with keys {@code db} (database list), {@code table}
	 *         (database name -> tables) and {@code column} (for PostgreSQL/SQL Server: database
	 *         name -> all columns of that database; otherwise: table name -> columns)
	 */
	@PostMapping(value = "/getEditorData")
	public ResponseJson getEditorData(Long sourceId) {
		BaseMapper baseMapper = this.getBaseMapper(sourceId);
		DatabaseFactoryBean databaseFactoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
		DatabaseProduct product = databaseFactoryBean.getDatabaseProduct();
		String defaultDbName = databaseFactoryBean.getDbName();
		// SQL Server and PostgreSQL are queried database by database and table by table.
		boolean perTableQuery = product == DatabaseProduct.SQLSERVER || product == DatabaseProduct.POSTGRESQL;
		List<DatabaseInfoDto> dbNameDtoList = baseMapper.getDatabaseList(defaultDbName);
		Map<String, List<TableInfoDto>> dbTableMap = new HashMap<>();
		Map<String, List<TableColumnDescDto>> tableColumnsMap = new HashMap<>();

		// MySQL can fetch the tables of every database in a single query.
		Map<String, List<TableInfoDto>> tableMapList = new HashMap<>();
		if (product == DatabaseProduct.MYSQL) {
			List<TableInfoDto> dbTableList = baseMapper.getTableList(null);
			tableMapList = dbTableList.stream().collect(Collectors.groupingBy(TableInfoDto::getDbName));
		}

		for (DatabaseInfoDto infoDto : dbNameDtoList) {
			String dbName = infoDto.getDbName();
			List<TableInfoDto> tableInfoDtoList = perTableQuery
					? baseMapper.getTableList(dbName)
					: tableMapList.get(dbName);
			if (CollectionUtils.isEmpty(tableInfoDtoList)) {
				continue;
			}
			dbTableMap.put(dbName, tableInfoDtoList);
			if (perTableQuery) {
				// One column query per table; results are flattened and keyed by database name.
				List<TableColumnDescDto> dbColumnList = new ArrayList<>();
				for (TableInfoDto tableInfoDto : tableInfoDtoList) {
					dbColumnList.addAll(baseMapper.getTableColumnList(dbName, tableInfoDto.getTableName()));
				}
				tableColumnsMap.put(dbName, dbColumnList);
			} else if (dbNameDtoList.size() <= 10 || Objects.equals(defaultDbName, dbName)) {
				// Bulk column loading is limited to servers with at most 10 databases, or to the
				// connection's default database.
				Map<String, List<TableColumnDescDto>> columnDescDtoMap = baseMapper.getTableColumnList(dbName, null)
						.stream()
						.collect(Collectors.groupingBy(TableColumnDescDto::getTableName));
				for (TableInfoDto tableInfoDto : tableInfoDtoList) {
					List<TableColumnDescDto> descDtoList = columnDescDtoMap.get(tableInfoDto.getTableName());
					if (CollectionUtils.isNotEmpty(descDtoList)) {
						tableColumnsMap.put(tableInfoDto.getTableName(), descDtoList);
					}
				}
			}
		}
		Map<String, Object> dbResultMap = new HashMap<>();
		dbResultMap.put("db", dbNameDtoList);
		dbResultMap.put("table", dbTableMap);
		dbResultMap.put("column", tableColumnsMap);
		return DocDbResponseJson.ok(dbResultMap);
	}

	/**
	 * Returns the CREATE TABLE statement of a table.
	 *
	 * <p>MySQL and Oracle read the DDL through the mapper. For PostgreSQL and SQL Server an
	 * approximate statement is assembled from the column metadata.
	 *
	 * @param sourceId  data source id
	 * @param dbName    database name, with or without surrounding backticks
	 * @param tableName table name
	 * @return response wrapping the DDL text, or a message for unsupported database types
	 */
	@PostMapping(value = "/getTableDdl")
	public ResponseJson getTableDdl(Long sourceId, String dbName, String tableName) {
		if (dbName == null) {
			return DocDbResponseJson.warn("请选择数据库");
		}
		DatabaseFactoryBean databaseFactoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
		if (databaseFactoryBean == null) {
			return DocDbResponseJson.warn(NO_CONNECTION_MSG);
		}
		DatabaseProduct product = databaseFactoryBean.getDatabaseProduct();
		// The MySQL/Oracle mappers are given the backtick-wrapped name. PostgreSQL and SQL Server
		// get the bare name, as getEditorData passes it.
		String quotedDbName = dbName.contains("`") ? dbName : "`" + dbName + "`";
		String plainDbName = dbName.replace("`", "");
		if (product == DatabaseProduct.MYSQL || product == DatabaseProduct.ORACLE) {
			BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
			Map<String, String> dataMap = baseMapper.getTableDdl(quotedDbName, tableName);
			String ddlKey = product == DatabaseProduct.MYSQL ? "Create Table" : "create_table";
			return DocDbResponseJson.ok(dataMap == null ? null : dataMap.get(ddlKey));
		}
		if (product == DatabaseProduct.POSTGRESQL) {
			TableColumnVo tableColumnVo = this.getTableColumnVo(databaseFactoryBean, plainDbName, tableName);
			return DocDbResponseJson.ok(this.buildPostgresDdl(plainDbName, tableName, tableColumnVo.getColumnList()));
		}
		if (product == DatabaseProduct.SQLSERVER) {
			TableColumnVo tableColumnVo = this.getTableColumnVo(databaseFactoryBean, plainDbName, tableName);
			return DocDbResponseJson.ok(this.buildSqlServerDdl(tableName, tableColumnVo.getColumnList()));
		}
		return DocDbResponseJson.ok("暂未支持的数据库类型");
	}

	/**
	 * Builds an approximate PostgreSQL CREATE TABLE statement, qualified as "dbName"."tableName"
	 * (dbName presumably names a schema here — confirm).
	 *
	 * <p>No COLLATE clause is emitted, because PostgreSQL rejects COLLATE on non-collatable
	 * types such as integers. Text columns therefore use the default collation.
	 */
	private String buildPostgresDdl(String dbName, String tableName, List<TableColumnDescDto> columnList) {
		StringJoiner columns = new StringJoiner(",\n");
		if (columnList != null) {
			for (TableColumnDescDto column : columnList) {
				columns.add("  \"" + column.getName() + "\" " + column.getType());
			}
		}
		return "CREATE TABLE \"" + dbName + "\".\"" + tableName + "\" (\n" + columns + "\n);";
	}

	/**
	 * Builds an approximate SQL Server CREATE TABLE statement in schema {@code dbo}.
	 *
	 * <p>Every column is declared {@code NULL}. A length suffix is appended only when the
	 * metadata has one, and columns are comma-separated.
	 */
	private String buildSqlServerDdl(String tableName, List<TableColumnDescDto> columnList) {
		StringJoiner columns = new StringJoiner(",\n");
		if (columnList != null) {
			for (TableColumnDescDto column : columnList) {
				String length = Objects.toString(column.getLength(), "");
				String type = StringUtils.isBlank(length) ? column.getType() : column.getType() + "(" + length + ")";
				columns.add("  [" + column.getName() + "] " + type + " NULL");
			}
		}
		return "CREATE TABLE [dbo].[" + tableName + "] (\n" + columns + "\n);";
	}

	/**
	 * Lists the databases visible to the data source.
	 *
	 * @param sourceId data source id
	 * @return response wrapping the database list
	 */
	@PostMapping(value = "/getDatabaseList")
	public ResponseJson getDatabaseList(Long sourceId) {
		BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
		DatabaseFactoryBean databaseFactoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
		List<DatabaseInfoDto> dbNameDtoList = baseMapper.getDatabaseList(databaseFactoryBean.getDbName());
		return DocDbResponseJson.ok(dbNameDtoList);
	}

	/**
	 * Returns status information of a table, tagged with the lower-case database product name.
	 * This is best effort: any failure, or a missing status, yields an empty success response.
	 */
	@PostMapping(value = "/getTableStatus")
	public ResponseJson getTableStatus(Long sourceId, String dbName, String tableName) {
		try {
			BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
			TableStatusVo tableStatusVo = baseMapper.getTableStatus(dbName, tableName);
			if (tableStatusVo == null) {
				return DocDbResponseJson.ok();
			}
			DatabaseFactoryBean factoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
			tableStatusVo.setDbType(factoryBean.getDatabaseProduct().name().toLowerCase(Locale.ROOT));
			return DocDbResponseJson.ok(tableStatusVo);
		} catch (Exception ignored) {
			// Table status is optional information; the caller gets an empty success response.
			return DocDbResponseJson.ok();
		}
	}

	/**
	 * Lists the tables of a database.
	 */
	@PostMapping(value = "/getTableList")
	public ResponseJson getTableList(Long sourceId, String dbName) {
		BaseMapper baseMapper = this.getBaseMapper(sourceId);
		List<TableInfoDto> dbTableList = baseMapper.getTableList(dbName);
		return DocDbResponseJson.ok(dbTableList);
	}

	/**
	 * Returns the columns and description of a table.
	 */
	@PostMapping(value = "/getTableColumnList")
	public ResponseJson getTableColumnList(Long sourceId, String dbName, String tableName) {
		DatabaseFactoryBean databaseFactoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
		if (databaseFactoryBean == null) {
			return DocDbResponseJson.warn(NO_CONNECTION_MSG);
		}
		TableColumnVo tableColumnVo = this.getTableColumnVo(databaseFactoryBean, dbName, tableName);
		return DocDbResponseJson.ok(tableColumnVo);
	}

	/**
	 * Returns the column descriptions (comments) of a table.
	 */
	@PostMapping(value = "/getTableColumnDescList")
	public ResponseJson getTableColumnDescList(Long sourceId, String tableName) {
		BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
		List<TableColumnDescDto> columnDescDto = baseMapper.getTableColumnDescList(tableName);
		return DocDbResponseJson.ok(columnDescDto);
	}

	/**
	 * Searches tables and/or columns whose names contain {@code searchText}.
	 *
	 * @param searchTable  whether table names are searched; {@code null} is treated as false
	 * @param searchColumn whether column names are searched; {@code null} is treated as false
	 */
	@PostMapping(value = "/getTableAndColumnBySearch")
	public ResponseJson getTableAndColumnBySearch(Long sourceId, String dbName, String searchText, Boolean searchTable, Boolean searchColumn) {
		BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
		if (StringUtils.isBlank(searchText)) {
			return DocDbResponseJson.ok();
		}
		String likeText = "%" + searchText + "%";
		// A whitespace-only pattern is passed for a disabled target (presumably so it matches
		// nothing — confirm against the mapper SQL).
		String tableText = Boolean.TRUE.equals(searchTable) ? likeText : "     ";
		String columnText = Boolean.TRUE.equals(searchColumn) ? likeText : "     ";
		List<QueryTableColumnDescDto> columnDescDto = baseMapper.getTableAndColumnBySearch(dbName, likeText, tableText, columnText);
		return DocDbResponseJson.ok(columnDescDto);
	}

	/**
	 * Returns the description (comment) rows of a table.
	 */
	@PostMapping(value = "/getTableDescList")
	public ResponseJson getTableDescList(Long sourceId, String dbName, String tableName) {
		BaseMapper baseMapper = this.getViewAuthBaseMapper(sourceId);
		List<TableDescDto> columnDescDto = baseMapper.getTableDescList(dbName, tableName);
		return DocDbResponseJson.ok(columnDescDto);
	}

	/**
	 * Updates the comment of a table.
	 */
	@PostMapping(value = "/updateTableDesc")
	public ResponseJson updateTableDesc(Long sourceId, String dbName, String tableName, String newDesc) {
		BaseMapper baseMapper = this.getBaseMapper(sourceId);
		baseMapper.updateTableDesc(dbName, tableName, newDesc);
		return DocDbResponseJson.ok();
	}

	/**
	 * Updates the comment of a column.
	 *
	 * <p>When a {@link MysqlMapper} is available for the source, the current column definition is
	 * read first and normalised into SQL fragments (nullability, DEFAULT clause, extra). It is
	 * then passed along, presumably so the update can restate the definition — confirm against
	 * the mapper.
	 */
	@PostMapping(value = "/updateTableColumnDesc")
	public ResponseJson updateTableColumnDesc(Long sourceId, String dbName, String tableName, String columnName, String newDesc) {
		BaseMapper baseMapper = this.getBaseMapper(sourceId);
		ColumnInfoDto columnInfo = null;
		MysqlMapper mysqlMapper = databaseRegistrationBean.getBaseMapper(sourceId, MysqlMapper.class);
		if (mysqlMapper != null) {
			columnInfo = mysqlMapper.getColumnInfo(dbName, tableName, columnName);
			if (columnInfo == null) {
				return DocDbResponseJson.warn("未找到对应的字段");
			}
			String isNullable = Optional.ofNullable(columnInfo.getIsNullable()).orElse("");
			columnInfo.setIsNullable("yes".equalsIgnoreCase(isNullable) ? "null" : "not null");
			String columnDefault = columnInfo.getColumnDefault();
			columnInfo.setColumnDefault(StringUtils.isNotBlank(columnDefault) ? "DEFAULT " + columnDefault : "");
			String extra = columnInfo.getExtra();
			columnInfo.setExtra(StringUtils.isBlank(extra) ? "" : extra);
		}
		baseMapper.updateTableColumnDesc(dbName, tableName, columnName, newDesc, columnInfo);
		return DocDbResponseJson.ok();
	}

	/**
	 * Exports the documentation of the given comma-separated tables.
	 *
	 * @param exportType 1 = text, 2 = xlsx, 3 = docx
	 */
	@PostMapping(value = "/exportDatabase")
	public ResponseJson exportDatabase(HttpServletResponse response, Long sourceId, String dbName, String tableNames, Integer exportType) {
		if (StringUtils.isBlank(tableNames)) {
			return DocDbResponseJson.warn("请选择需要导出的表");
		}
		DatabaseFactoryBean databaseFactoryBean = databaseRegistrationBean.getOrCreateFactoryById(sourceId);
		if (databaseFactoryBean == null) {
			return DocDbResponseJson.warn(NO_CONNECTION_MSG);
		}
		List<TableInfoVo> tableList = new LinkedList<>();
		Map<String, List<TableColumnDescDto>> columnList = new HashMap<>();
		for (String tableName : tableNames.split(",")) {
			if (StringUtils.isBlank(tableName)) {
				continue;
			}
			TableColumnVo tableColumnVo = this.getTableColumnVo(databaseFactoryBean, dbName, tableName);
			columnList.put(tableName, tableColumnVo.getColumnList());
			tableList.add(tableColumnVo.getTableInfo());
		}
		DatabaseExportVo exportVo = new DatabaseExportVo();
		exportVo.setColumnList(columnList);
		exportVo.setTableList(tableList);
		try {
			if (Objects.equals(exportType, 1)) {
				PoiUtils.exportByText(exportVo, response);
			} else if (Objects.equals(exportType, 2)) {
				PoiUtils.exportByXlsx(exportVo, response);
			} else if (Objects.equals(exportType, 3)) {
				PoiUtils.exportByDocx(dbName, exportVo, response);
			} else {
				return DocDbResponseJson.error("导出失败：请先选择导出类型");
			}
		} catch (Exception e) {
			e.printStackTrace();
			return DocDbResponseJson.error("导出失败：" + e.getMessage());
		}
		return DocDbResponseJson.ok();
	}

	/**
	 * Loads the columns and the description of one table.
	 * For SQL Server, column comments are fetched by a separate query and merged by column name.
	 */
	private TableColumnVo getTableColumnVo(DatabaseFactoryBean databaseFactoryBean, String dbName, String tableName) {
		SqlSessionTemplate sessionTemplate = databaseFactoryBean.getSqlSessionTemplate();
		BaseMapper baseMapper = sessionTemplate.getMapper(BaseMapper.class);
		List<TableColumnDescDto> columnDescDto = baseMapper.getTableColumnList(dbName, tableName);
		if (databaseFactoryBean.getDatabaseProduct() == DatabaseProduct.SQLSERVER) {
			List<TableColumnDescDto> columnDescList = baseMapper.getTableColumnDescList(tableName);
			// Keep the first entry on duplicate names instead of failing with IllegalStateException.
			Map<String, TableColumnDescDto> columnMap = columnDescDto.stream()
					.collect(Collectors.toMap(TableColumnDescDto::getName, val -> val, (first, second) -> first));
			if (columnDescList != null) {
				for (TableColumnDescDto descDto : columnDescList) {
					TableColumnDescDto tempDesc = columnMap.get(descDto.getName());
					if (tempDesc != null) {
						tempDesc.setDescription(descDto.getDescription());
					}
				}
			}
		}
		TableColumnVo tableColumnVo = new TableColumnVo();
		tableColumnVo.setColumnList(columnDescDto);
		// Table description: first row of the description query, or "" when absent.
		TableColumnVo.TableInfoVo tableInfoVo = new TableColumnVo.TableInfoVo();
		List<TableDescDto> tableDescList = baseMapper.getTableDescList(dbName, tableName);
		String description = null;
		if (CollectionUtils.isNotEmpty(tableDescList)) {
			description = tableDescList.get(0).getDescription();
		}
		tableInfoVo.setDescription(Optional.ofNullable(description).orElse(""));
		tableInfoVo.setTableName(tableName);
		tableColumnVo.setTableInfo(tableInfoVo);
		return tableColumnVo;
	}

	/**
	 * Returns the registered BaseMapper of a data source.
	 *
	 * @throws ConfirmException when no connection is registered for {@code sourceId}
	 */
	private BaseMapper getBaseMapper(Long sourceId) {
		BaseMapper baseMapper = databaseRegistrationBean.getBaseMapperById(sourceId);
		if (baseMapper == null) {
			throw new ConfirmException(NO_CONNECTION_MSG);
		}
		return baseMapper;
	}

	/**
	 * Returns the BaseMapper for read ("view") operations. No permission check is performed here
	 * at present; it delegates to {@link #getBaseMapper(Long)}.
	 */
	private BaseMapper getViewAuthBaseMapper(Long sourceId) {
		return this.getBaseMapper(sourceId);
	}

}

